from tqdm import tqdm
from utils.utils import call_model, remove_extra_target_occurrences, able_to_quit, retrieve_facts, \
    get_sent_embeddings
from mquake_dataset import MQUAKE
import json


def _collect_edit_facts(mquake_dataset, rand_list):
    """Return the de-duplicated edit sentences of every case whose id is in ``rand_list``.

    Each sentence is built from a ``requested_rewrite`` entry as
    ``prompt.format(subject) + " " + target_new["str"]``. Insertion order is
    preserved so the fact list (and therefore retrieved indices) is the same
    from run to run.
    """
    facts = {}
    for case in mquake_dataset.get_dataset():
        if case['case_id'] not in rand_list:
            continue
        for rewrite in case["requested_rewrite"]:
            facts[f'{rewrite["prompt"].format(rewrite["subject"])} {rewrite["target_new"]["str"]}'] = None
    return list(facts)


def _dump_results(result_file_path, raw_answer_dict):
    """Overwrite ``result_file_path`` with ``raw_answer_dict`` serialized as JSON."""
    with open(result_file_path, 'w', encoding='utf-8') as f:
        json.dump(raw_answer_dict, f)


def mello_eval_loop(mquake_dataset, task_prompt, sc_fact, rand_list, model, model_tokenizer, device,
                    contriever, tokenizer, print_prompt, masking, logger, result_file_path, pre_token_length=4,
                    max_hops=4):
    """Run the MeLLo-style multi-hop evaluation over every case in the dataset.

    For each case, every question is tried in turn: the model is prompted for
    up to ``max_hops`` rounds, each round optionally followed by the closest
    retrieved edit fact, until ``able_to_quit`` reports a final answer. A case
    counts as correct as soon as one of its questions is answered correctly
    (remaining questions are then skipped); it additionally counts as
    hop-correct when ``verify_subquestion_path`` accepts the generated path.

    Args:
        mquake_dataset: Object providing ``get_dataset()``, ``check_answer()``,
            ``verify_subquestion_path()`` and, when ``masking`` is set,
            ``get_edits_without_contamination()``.
        task_prompt: Few-shot prompt prefix placed before every question.
        sc_fact: Forwarded unchanged to ``call_model``.
        rand_list: Case ids treated as edited. When empty, no fact retrieval
            is performed.
        model, model_tokenizer, device: Forwarded unchanged to ``call_model``.
        contriever, tokenizer: Retriever and its tokenizer, forwarded to
            ``get_sent_embeddings`` and ``retrieve_facts``.
        print_prompt: When truthy, log the generated text of every question.
        masking: When truthy, the fact pool is rebuilt per case through
            ``get_edits_without_contamination``; otherwise one shared pool of
            all edits from the cases in ``rand_list`` is used.
        logger: Object with an ``info(message)`` method.
        result_file_path: JSON file that receives the raw answers (rewritten
            every 10 cases and once more at the end).
        pre_token_length: Number of leading characters dropped from the text
            returned by ``remove_extra_target_occurrences``.
            NOTE(review): presumably strips a tokenizer-specific prefix — confirm.
        max_hops: Maximum number of generate/retrieve rounds per question.

    Returns:
        None. Results are reported through ``logger`` and ``result_file_path``.
    """
    cor = 0
    h_cor = 0  # hop correctness, following PokeMQA
    tot = 0
    raw_answer_dict = {}

    # Without masking the fact pool does not depend on the current case, so it
    # is built and embedded once instead of once per case.
    shared_facts = shared_embs = None
    if not masking:
        shared_facts = _collect_edit_facts(mquake_dataset, rand_list) or ["No relevant fact."]
        shared_embs = get_sent_embeddings(shared_facts, contriever, tokenizer)

    for case in mquake_dataset.get_dataset():
        case_id = case['case_id']
        edit_flag = case_id in rand_list
        raw_answer_dict[case_id] = {'edited': edit_flag}

        if masking:
            new_facts, _, _, _ = mquake_dataset.get_edits_without_contamination(rand_list, case)
            if not new_facts:
                # Keep the retriever working on a non-empty pool.
                new_facts = ["No relevant fact."]
            embs = get_sent_embeddings(new_facts, contriever, tokenizer)
        else:
            new_facts, embs = shared_facts, shared_embs

        tot += 1
        llm_answers = []

        for question in case["questions"]:
            found_ans = False
            ans = None
            prompt = task_prompt + "\n\nQuestion: " + question

            for _ in range(max_hops):
                # Prompt the model to generate a subquestion and a tentative answer;
                # the returned text replaces the running prompt.
                prompt = call_model(prompt, sc_fact, model, model_tokenizer, device)
                # Drop a dangling, empty "Retrieved fact:" line left by the generation.
                if prompt.strip().split('\n')[-1] == 'Retrieved fact:':
                    prompt = prompt[:-len('\nRetrieved fact:')]
                prompt = remove_extra_target_occurrences(prompt, "Question: ", 5)[pre_token_length:]

                # If the final answer is there, take it and stop.
                done, ans = able_to_quit(prompt, task_prompt)
                if done:
                    found_ans = True
                    break

                # Otherwise the second-to-last line must be the generated subquestion.
                lines = prompt.strip().split('\n')
                if len(lines) < 2:
                    break  # failed case
                subquestion = lines[-2]
                if not subquestion.startswith('Subquestion: '):
                    break  # failed case

                if rand_list:
                    fact_ids = retrieve_facts(subquestion, embs, contriever, tokenizer)
                    fact_sent = new_facts[fact_ids[0]]
                    # Put the retrieved fact at the end of the prompt; the model
                    # self-checks whether it contradicts the tentative answer.
                    prompt = prompt + '\nRetrieved fact: ' + fact_sent + "\n"

                done, ans = able_to_quit(prompt, task_prompt)
                if done:
                    found_ans = True
                    break

            if print_prompt:
                logger.info(prompt[len(task_prompt):] + "\n")
            if not found_ans:
                continue
            llm_answers.append(ans)

            # One correct question is enough for the whole case.
            if mquake_dataset.check_answer(edit_flag, case, ans):
                cor += 1
                if mquake_dataset.verify_subquestion_path(prompt[len(task_prompt):], case, edit_flag):
                    h_cor += 1
                break

        raw_answer_dict[case_id]['answers'] = llm_answers

        # Checkpoint periodically so a crash does not lose all answers.
        if tot % 10 == 0:
            _dump_results(result_file_path, raw_answer_dict)

        logger.info(f"{cor} ({h_cor}), {tot}")

    # An empty dataset yields accuracy 0.0 instead of a ZeroDivisionError.
    accuracy = cor / tot if tot else 0.0
    logger.info(f'Multi-hop acc = {accuracy} ({cor} / {tot})')

    _dump_results(result_file_path, raw_answer_dict)
